import json
import sys

import litellm
from google.genai.types import UserContent

if sys.platform == "win32":
    # Windows only: install the selector-based asyncio event loop policy at
    # import time, so it is in place before `uvicorn.run(...)` at the bottom of
    # this file starts the server's event loop.
    # NOTE(review): the MCP toolsets in this file launch `java` / `npx` child
    # processes over stdio; the Windows selector loop is not documented to
    # support asyncio subprocesses — confirm the MCP client works under this
    # policy (or whether uvicorn overrides it) before relying on it.
    import asyncio

    asyncio.set_event_loop_policy(asyncio.WindowsSelectorEventLoopPolicy())
import uvicorn
from fastapi import FastAPI, Body
from google.adk import Runner
from google.adk.artifacts import InMemoryArtifactService
from google.adk.sessions import InMemorySessionService
from google.adk.tools.mcp_tool import MCPToolset
from google.genai import types
from mcp import StdioServerParameters
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware

from gen_pic.agent import root_agent
from gen_pic.tool.custom_tool import get_weather_stateful

# Turn on LiteLLM's verbose debug output for all model calls made by the agent.
# NOTE(review): `_turn_on_debug` is a private LiteLLM helper; consider gating it
# behind a configuration flag outside of development.
litellm._turn_on_debug()

# OpenAPI metadata for the service.
# NOTE(review): title/description still mention LangChain although this server
# is built on Google ADK; they only affect the generated API docs.
_APP_METADATA = {
    "title": "LangChain Server",
    "version": "1.0",
    "description": "A simple api server using Langchain's Runnable interfaces",
}
app = FastAPI(**_APP_METADATA)

# Fully open CORS policy: every origin, method and header is accepted and all
# response headers are exposed to the browser.
# NOTE(review): wildcard origins combined with allow_credentials=True deserves a
# security review before this server is exposed beyond localhost.
_CORS_OPTIONS = {
    "allow_origins": ["*"],
    "allow_credentials": True,
    "allow_methods": ["*"],
    "allow_headers": ["*"],
    "expose_headers": ["*"],
}
app.add_middleware(CORSMiddleware, **_CORS_OPTIONS)


async def get_tools_async(root_dir: str = "E:\\project\\genify\\gen_pic"):
    """Build an MCP toolset backed by the filesystem MCP server.

    The server is started as ``npx -y @modelcontextprotocol/server-filesystem
    <root_dir>`` and spoken to over stdio.

    Args:
        root_dir: Absolute path on this machine that the filesystem server is
            given. Defaults to the original hard-coded Windows path; pass a
            different path when running on another machine.

    Returns:
        A new ``MCPToolset``. The caller owns it and should ``await
        toolset.close()`` when finished with it.
    """
    return MCPToolset(
        connection_params=StdioServerParameters(
            command="npx",
            args=[
                "-y",  # npx flag: answer "yes" to the package-install prompt
                "@modelcontextprotocol/server-filesystem",
                root_dir,
            ],
        )
    )


async def get_gen_pic(
    jar_path: str = "/mnt/e/project/genify/gen_pic/test/ai-0.0.1-SNAPSHOT.jar",
):
    """Build an MCP toolset backed by the Java (Spring AI) picture server.

    The server is started as ``java <spring flags> -jar <jar_path>`` and spoken
    to over stdio. The Spring flags enable the stdio MCP transport, set
    ``spring.main.web-application-type=none`` and blank the console log
    pattern (presumably so log lines do not corrupt the stdio protocol stream).

    Args:
        jar_path: Path to the server jar. Defaults to the original hard-coded
            location; pass a different path when running on another machine.

    Returns:
        A new ``MCPToolset``. The caller owns it and should ``await
        toolset.close()`` when finished with it.
    """
    spring_flags = [
        "-Dspring.ai.mcp.server.stdio=true",
        "-Dspring.main.web-application-type=none",
        "-Dlogging.pattern.console=",
    ]
    return MCPToolset(
        connection_params=StdioServerParameters(
            command="java",
            args=[*spring_flags, "-jar", jar_path],
        )
    )


# Picture-generation toolset registered on ``root_agent`` by the most recent
# get_agent_async() call, remembered so the next call can replace it instead of
# stacking another one next to it.
_registered_gen_pic_toolset = None


async def get_agent_async():
    """Attach this server's tools to the shared ``root_agent`` and return it.

    ``root_agent`` is a module-level singleton, but this coroutine runs once
    per request. Registration is therefore idempotent:

    * ``get_weather_stateful`` is appended only if it is not already present;
    * the picture-generation toolset created by the previous call (which the
      caller closes when its request finishes) is removed before a freshly
      created one is appended.

    Previously every call appended both again, so the tool list grew with
    duplicate function tools and already-closed toolsets.

    NOTE(review): concurrent requests share ``root_agent``, so one request can
    swap out the toolset another request is still using; serialize requests or
    build a per-request agent if that matters.

    Returns:
        ``(agent, toolset)``: the configured agent and the newly created
        picture-generation ``MCPToolset``, which the caller must close.
    """
    global _registered_gen_pic_toolset

    tools = root_agent.tools
    if not any(t is get_weather_stateful for t in tools):
        tools.append(get_weather_stateful)

    previous = _registered_gen_pic_toolset
    if previous is not None:
        # Identity comparison, so only that exact toolset object is dropped.
        tools[:] = [t for t in tools if t is not previous]

    toolset = await get_gen_pic()
    tools.append(toolset)
    _registered_gen_pic_toolset = toolset

    # ADK saves the agent's final response into session state under this key.
    root_agent.output_key = "last_weather_report"
    return root_agent, toolset


# Process-wide service singletons shared by every request handled by this app.
# Being in-memory, their sessions/artifacts presumably do not survive a restart.
session_service: InMemorySessionService = InMemorySessionService()
artifacts_service: InMemoryArtifactService = InMemoryArtifactService()


def _event_parts(event):
    """Return the content parts of an ADK event, or ``[]`` when it has none.

    Events may arrive without ``content`` or with ``content.parts`` unset, so
    both attributes are read defensively instead of being dereferenced.
    """
    content = getattr(event, "content", None)
    return list(getattr(content, "parts", None) or [])


async def async_main(
    query,
    *,
    app_name="pic_gen",
    user_id="traveler0115",
    session_id="1234567890",
):
    """Send ``query`` to the agent and return the last textual result.

    The session ``(app_name, user_id, session_id)`` is created with an initial
    state on first use and reused afterwards. Every event is printed; the
    returned text is whichever came last among the first part's text of an
    event and a stringified function response.

    Args:
        query: User message passed to the agent.
        app_name: ADK application name for the session and runner.
        user_id: User the session belongs to.
        session_id: Session to reuse or create.

    Returns:
        The final response text, or a default message if the agent produced
        none.

    Raises:
        Whatever the session service or runner raises. The per-request MCP
        toolset is closed in every case.
    """
    # Seed state for the weather tool; only applied when the session is new.
    initial_state = {"user_preference_temperature_unit": "Celsius"}
    session = await session_service.get_session(
        app_name=app_name, user_id=user_id, session_id=session_id
    )
    if not session:
        await session_service.create_session(
            state=initial_state,
            app_name=app_name,
            user_id=user_id,
            session_id=session_id,
        )

    agent, tool = await get_agent_async()
    try:
        runner = Runner(
            app_name=app_name,
            agent=agent,
            artifact_service=artifacts_service,
            session_service=session_service,
        )
        # Wrap the user's text as an ADK user message.
        content = UserContent(query)

        final_response_text = "Agent did not produce a final response."  # Default
        print(f"<<< Sending message to agent: {content}")
        async for event in runner.run_async(
            session_id=session_id, user_id=user_id, new_message=content
        ):
            print(f"收到事件#########：{event}")
            print("\n***********************************************\n")

            author = event.author
            parts = _event_parts(event)
            function_calls = [p.function_call for p in parts if p.function_call]
            function_responses = [
                p.function_response for p in parts if p.function_response
            ]

            if parts and parts[0].text:
                text_response = parts[0].text
                final_response_text = text_response
                print(f"event.content.parts#########: \n[{author}]: {text_response}")
                print("\n***********************************************\n")

            if function_calls:
                for function_call in function_calls:
                    print(
                        f"function_call#########: \n[{author}]: "
                        f"{function_call.name}( {json.dumps(function_call.args)} )"
                    )
                    print("\n***********************************************\n")
            elif function_responses:
                for function_response in function_responses:
                    function_name = function_response.name
                    application_payload = function_response.response
                    print(
                        f"function_response#########: [{author}]: "
                        f"{function_name} responds -> {str(application_payload)}"
                    )
                    print("\n***********************************************\n")
                    # A tool result counts as the answer unless later text overrides it.
                    final_response_text = str(application_payload)

        print(f"<<< Agent Response : {final_response_text}")
        return final_response_text
    finally:
        # Close the per-request MCP toolset even when the run fails, so its
        # server process/connection is not leaked.
        await tool.close()


class QueryRequest(BaseModel):
    """JSON request body accepted by ``POST /root``."""

    query: str  # user message handed to the agent unchanged


@app.post("/root")
async def root(request: QueryRequest = Body(...)):
    """Run one agent turn for the posted query.

    Responds with ``{"message": <final response text>}``.
    """
    print(f"received reqeust : {request}")
    answer = await run_conversation(request.query)
    payload = {"message": answer}
    return payload


async def run_conversation(query: str):
    """Run a single agent turn for ``query`` and return the response text.

    Thin wrapper over ``async_main`` used by the ``POST /root`` handler.
    """
    return await async_main(query)


if __name__ == "__main__":
    # Script entry point: serve the FastAPI app on localhost:8000.
    uvicorn.run(app, host="localhost", port=8000)
